{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa2e1265-86b9-4d06-9e96-1f12a8167a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/file_utils.ipynb\"\n",
    "%run -i \"../util/util_simple_classifier.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a6155c09-203e-41f3-9994-ad72c1b1aa9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from sklearn.metrics import classification_report\n",
    "from openai import OpenAI\n",
    "# NOTE(review): OPEN_AI_KEY is not defined in this notebook; presumably it is set by\n",
    "# one of the util notebooks executed with %run in the first cell - confirm.\n",
    "client = OpenAI(api_key=OPEN_AI_KEY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "410bbf87-e0fa-43dd-be78-487793741f54",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n",
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    }
   ],
   "source": [
    "# Load the train and test splits of the SetFit/bbc-news dataset, in that order.\n",
    "# NOTE(review): load_dataset is not imported here; presumably it comes from the\n",
    "# util notebooks executed with %run - confirm.\n",
    "train_dataset, test_dataset = (\n",
    "    load_dataset(\"SetFit/bbc-news\", split=split_name)\n",
    "    for split_name in (\"train\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13e520c1-eb09-4218-9154-3874c7c2bb63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "carry on star patsy rowlands dies actress patsy rowlands  known to millions for her roles in the carry on films  has died at the age of 71.  rowlands starred in nine of the popular carry on films  alongside fellow regulars sid james  kenneth williams and barbara windsor. she also carved out a successful television career  appearing for many years in itv s well-loved comedy bless this house. rowlands died in hove on saturday morning  her agent said.  born in january 1934  rowlands won a scholarship to the guildhall school of speech and drama scholarship when she was just 15.  after spending several years at the players theatre in london  she made her film debut in 1963 in tom jones  directed by tony richardson. she made her first carry on film in 1969 where she appeared in carry on again doctor. rowlands played the hard-done-by wife or the put-upon employee as a regular carry on star. she also appeared in carry on at your convenience  carry on matron and carry on loving  as well as others.  in recent years she appeared in bbc mini-series the cazalets and played mrs potts in the london stage version of beauty and the beast. agent simon beresford said:  she was just an absolutely favourite client she never complained about anything  particularly when she was ill  she was an old trouper.  she was of the old school - she had skills from musical theatre and high drama  that is why she worked with the great and the good of directors.  she didn t mind always being recognised for the carry on films because she thoroughly enjoyed making them. she was a really lovely person and she will be much missed.  her last appearance on stage was as mrs pearce in the award-winning production of my fair lady at the national theatre. previously married  she leaves one son  alan. her funeral will be a private  family occasion  with a memorial service at a later date.\n",
      "entertainment\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first test record: its article text and its topic name.\n",
    "example, category = (test_dataset[0][field] for field in (\"text\", \"label_text\"))\n",
    "print(example, category, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1ebe9f5-68a4-4855-af5e-7043ef7034c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "entertainment\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: classify the single example article with gpt-3.5-turbo\n",
    "prompt = (\n",
    "    \"\"\"You are classifying texts by topics. There are 5 topics: tech, entertainment, business, politics and sport. \n",
    "Output the topic and nothing else. For example, if the topic is business, your output should be \"business\".\n",
    "Give the following text, what is its topic from the above list without any additional explanations: \"\"\"\n",
    "    + example\n",
    ")\n",
    "response = client.chat.completions.create(\n",
    "    messages=[\n",
    "        dict(role=\"system\", content=\"You are a helpful assistant.\"),\n",
    "        dict(role=\"user\", content=prompt),\n",
    "    ],\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    max_tokens=256,\n",
    "    # temperature 0 is used to keep the reply as repeatable as possible\n",
    "    temperature=0,\n",
    "    top_p=1.0,\n",
    "    presence_penalty=0,\n",
    "    frequency_penalty=0,\n",
    ")\n",
    "# Only the first returned choice is inspected.\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "494814c4-ab0a-4fb6-8681-4051b80d75c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate GPT classification on several examples\n",
    "def get_gpt_classification(input_text, model=\"gpt-3.5-turbo\"):\n",
    "    \"\"\"Ask an OpenAI chat model which of the five topics ``input_text`` belongs to.\n",
    "\n",
    "    Args:\n",
    "        input_text: Text to classify; it is appended verbatim to the prompt.\n",
    "        model: Chat model name sent to the API. The default is the model this\n",
    "            notebook has always used.\n",
    "\n",
    "    Returns:\n",
    "        The reply of the first choice, lower-cased and stripped of surrounding\n",
    "        whitespace. An empty string is returned when the reply carries no text\n",
    "        (``message.content`` is ``None``), so one odd response does not abort\n",
    "        a long batch of paid API calls with an AttributeError.\n",
    "    \"\"\"\n",
    "    prompt = \"\"\"You are classifying texts by topics. There are 5 topics: tech, entertainment, business, politics and sport. \n",
    "Output the topic and nothing else. For example, if the topic is business, your output should be \"business\".\n",
    "Give the following text, what is its topic from the above list without any additional explanations: \"\"\" + input_text\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        temperature=0,\n",
    "        max_tokens=256,\n",
    "        top_p=1.0,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "            {\"role\": \"user\", \"content\": prompt},\n",
    "        ],\n",
    "    )\n",
    "    classification = response.choices[0].message.content\n",
    "    if classification is None:\n",
    "        # The API types content as optional; normalise a missing reply to an empty string.\n",
    "        return \"\"\n",
    "    return classification.lower().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "051f81a2-80ae-448e-a5f3-318507ebb968",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataFrame.sample returns a new frame instead of shuffling in place, so its result\n",
    "# must be assigned; previously it was discarded and the subset was simply the first\n",
    "# 200 rows. The fixed seed keeps the shuffled subset reproducible across runs.\n",
    "test_df = (\n",
    "    test_dataset.to_pandas()\n",
    "    .sample(frac=1, random_state=42)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Keep the evaluation small: one API call is made per row further down.\n",
    "test_data = test_df.iloc[:200].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dd1cce33-4049-4c30-b465-9774275a902c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One get_gpt_classification call per row of the text column; the reply is stored as-is.\n",
    "test_data[\"gpt_prediction\"] = test_data[\"text\"].apply(get_gpt_classification)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a4b09195-74cd-4e6f-97af-c882a6a402a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_one_word_match(input_text):\n",
    "    \"\"\"Return the first topic keyword that occurs in ``input_text``.\n",
    "\n",
    "    The search is case-sensitive and matches substrings, so a reply such as\n",
    "    'technology' yields 'tech'.\n",
    "\n",
    "    Args:\n",
    "        input_text: Model reply to scan for one of the five topic names.\n",
    "\n",
    "    Returns:\n",
    "        One of 'tech', 'entertainment', 'business', 'sport' or 'politics'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If none of the topic names occurs in ``input_text``.\n",
    "    \"\"\"\n",
    "    match = re.search(r'tech|entertainment|business|sport|politics', input_text)\n",
    "    if match is None:\n",
    "        # re.search returns None on no match; fail with a message that shows the reply.\n",
    "        raise ValueError(f\"No known topic found in model output: {input_text!r}\")\n",
    "    return match.group(0)\n",
    "# Collapse every stored reply to the topic keyword it contains.\n",
    "test_data[\"gpt_prediction\"] = test_data[\"gpt_prediction\"].apply(get_one_word_match)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1e890aea-9205-4393-9a6b-a7a34110ef06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A topic's position in this list is used as its integer label, which is later\n",
    "# compared against the dataset's own label column.\n",
    "label_list = [\n",
    "    \"tech\",\n",
    "    \"business\",\n",
    "    \"sport\",\n",
    "    \"entertainment\",\n",
    "    \"politics\",\n",
    "]\n",
    "test_data[\"gpt_label\"] = test_data[\"gpt_prediction\"].apply(label_list.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "72d1fd90-9904-4cb7-a5f5-8ea6bfee333c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                  text  label     label_text  \\\n",
      "0    carry on star patsy rowlands dies actress pats...      3  entertainment   \n",
      "1    sydney to host north v south game sydney will ...      2          sport   \n",
      "2    uk coal plunges into deeper loss shares in uk ...      1       business   \n",
      "3    blair joins school sailing trip the prime mini...      4       politics   \n",
      "4    bath faced with tindall ultimatum mike tindall...      2          sport   \n",
      "..                                                 ...    ...            ...   \n",
      "195  french wine gets 70m euro top-up the french go...      1       business   \n",
      "196  argonaut founder rebuilds empire jez san  the ...      0           tech   \n",
      "197  half-life 2 sweeps bafta awards pc first perso...      0           tech   \n",
      "198  xbox power cable  fire fear  microsoft has sai...      0           tech   \n",
      "199  prop jones ready for hard graft adam jones say...      2          sport   \n",
      "\n",
      "    gpt_prediction  gpt_label  \n",
      "0    entertainment          3  \n",
      "1            sport          2  \n",
      "2         business          1  \n",
      "3         politics          4  \n",
      "4            sport          2  \n",
      "..             ...        ...  \n",
      "195       business          1  \n",
      "196       business          1  \n",
      "197  entertainment          3  \n",
      "198           tech          0  \n",
      "199          sport          2  \n",
      "\n",
      "[200 rows x 5 columns]\n"
     ]
    }
   ],
   "source": [
    "# Print the evaluation frame, which by now also holds the gpt_prediction and gpt_label columns.\n",
    "print(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5883f8cf-1e28-4d5f-81c1-2cc7e08a4b70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "         tech       0.97      0.80      0.88        41\n",
      "     business       0.87      0.89      0.88        44\n",
      "        sport       1.00      0.96      0.98        48\n",
      "entertainment       0.88      0.90      0.89        40\n",
      "     politics       0.76      0.96      0.85        27\n",
      "\n",
      "     accuracy                           0.90       200\n",
      "    macro avg       0.90      0.90      0.90       200\n",
      " weighted avg       0.91      0.90      0.90       200\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pass the label ids explicitly so the report rows always line up with target_names;\n",
    "# without them classification_report raises when a topic is missing from both the\n",
    "# true and the predicted labels of this small subset.\n",
    "print(\n",
    "    classification_report(\n",
    "        test_data[\"label\"],\n",
    "        test_data[\"gpt_label\"],\n",
    "        labels=list(range(len(label_list))),\n",
    "        target_names=label_list,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7eb7aa4-b846-4a2f-beae-0940fb8936d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
